import numpy as np
import copy
import argparse
import pickle

import matplotlib.pyplot as plt
#from scipy import sparse

class AlesiaEnv:
    """Two-player Alesia bidding game on a board of ``length`` cells.

    A marker starts at cell ``floor(length / 2) + 1``.  Each step both players
    submit an action index ``a``, which bids ``a + 1`` units of budget, capped
    at the remaining budget (once a budget is spent the effective index is -1,
    i.e. a zero-unit bid).  The higher effective bid wins the round:

    * x's higher bid moves the marker one cell down, y's one cell up, and a tie
      leaves it in place;
    * reaching cell 0 is an x victory (reward -1.0, episode done);
    * reaching cell ``length + 1`` is a y victory (reward +1.0, episode done);
    * a step taken when both budgets are already zero returns reward 0.0 and
      ``done=True`` without moving the marker.

    Observations are float arrays ``[position, budget_x, budget_y]``.
    ``x_victories`` / ``y_victories`` record 1/0 for every decided episode
    and are never cleared by ``reset``.
    """

    def __init__(self, players_budget=20, length=5):
        self.players_budget = players_budget
        self.budget_x = players_budget
        self.budget_y = players_budget
        self.x_victories = []
        self.y_victories = []
        self.length = length
        # Initialise the marker so that step() also works before reset().
        self.state = self._initial_position()

    def _initial_position(self):
        """Return the starting cell of the marker (middle of the board)."""
        return np.floor(self.length / 2) + 1

    def reset(self):
        """Restore both budgets and the marker; return the initial observation."""
        self.budget_x = self.players_budget
        self.budget_y = self.players_budget
        self.state = self._initial_position()
        return np.array([self.state, self.budget_x, self.budget_y])

    def step(self, action_x, action_y):
        """Play one bidding round.

        Args:
            action_x: x's action index (bids ``action_x + 1`` units).
            action_y: y's action index (bids ``action_y + 1`` units).

        Returns:
            ``(observation, reward, done)`` as described in the class docstring.
        """
        if self.budget_x == 0 and self.budget_y == 0:
            return (np.array([self.state, self.budget_x, self.budget_y]),
                    0.0,
                    True)

        # Nobody can bid more than what is left: cap the index at budget - 1.
        action_x = min(action_x, self.budget_x - 1)
        action_y = min(action_y, self.budget_y - 1)
        self.budget_x -= action_x + 1
        self.budget_y -= action_y + 1

        # int() replaces the np.int alias, which NumPy 1.24 removed.
        next_state = (self.state + int(action_x <= action_y)
                      - int(action_y <= action_x))
        reward = 0.0
        done = False
        if next_state == 0:
            self.x_victories.append(1)
            self.y_victories.append(0)
            reward, done = -1.0, True
        elif next_state == self.length + 1:
            self.x_victories.append(0)
            self.y_victories.append(1)
            reward, done = 1.0, True
        self.state = next_state
        return (np.array([self.state, self.budget_x, self.budget_y]),
                reward,
                done)


def get_state_index(state, players_budget):
    """Flatten a ``(grid_position, budget_x, budget_y)`` triple into an index.

    ``grid_position`` is 1-based; each budget ranges over
    ``0 .. players_budget``.  This is the inverse of ``get_state``.

    Returns:
        A Python ``int`` (``np.int`` was removed in NumPy 1.24, so the builtin
        is used; float components are truncated exactly as before).
    """
    side = players_budget + 1
    return int((state[0] - 1) * side ** 2 + state[1] * side + state[2])


def get_state(config, players_budget):
    """Decode a flat state index back into ``[grid_position, budget_x, budget_y]``.

    Inverse of ``get_state_index``; the result is a length-3 float array with
    a 1-based grid position.
    """
    side = players_budget + 1
    plane = side ** 2
    position = np.floor(config / plane) + 1
    remainder = config - (position - 1) * plane
    budget_x = np.floor(remainder / side)
    remainder -= budget_x * side
    return np.array([position, budget_x, remainder], dtype=float)


def create_transition_matrix(players_budget, length):
    """Build the joint transition and reward tensors of the Alesia game.

    Non-terminal states ``0 .. (players_budget + 1) ** 2 * length - 1`` are
    the flattened ``(grid_position, budget_x, budget_y)`` triples of
    ``get_state`` / ``get_state_index``; the extra last index is an absorbing
    terminal state.  Action ``a`` bids ``a + 1`` units, capped at the
    bidder's remaining budget in the same way as ``AlesiaEnv.step``.

    Unlike ``AlesiaEnv``, which only stops once *both* budgets are spent,
    a transition here goes to the terminal state as soon as *either* budget
    reaches zero.

    Returns:
        transition_matrix: array of shape ``(n + 1, n + 1, B, B)``;
            ``transition_matrix[next, current, action_x, action_y]`` is 1.0
            for the deterministic successor and 0.0 elsewhere.
        reward_nash: array of shape ``(n + 1, B, B)``; -1.0 when the move
            reaches grid position 0, +1.0 when it reaches ``length + 1``,
            0.0 otherwise.
    """
    total_configurations = (players_budget + 1) ** 2 * length
    n_actions = players_budget
    terminal_index = -1
    transition_matrix = np.zeros((total_configurations + 1,
                                  total_configurations + 1,
                                  n_actions,
                                  n_actions))
    reward_nash = np.zeros((total_configurations + 1, n_actions, n_actions))

    def spend(action, budget):
        """Return ``(remaining budget, effective bid index)`` after bidding."""
        if action >= budget - 1:
            # Bid everything that is left.  The original code wrote
            # np.max(budget - 1, 0), but for a scalar the 0 is the *axis*
            # argument, so the value was budget - 1 (-1, a zero-unit bid,
            # once the budget is spent).  This matches AlesiaEnv.step and
            # is kept deliberately.
            return 0, budget - 1
        return budget - action - 1, action

    for config in range(total_configurations):
        grid_position, budget_x, budget_y = get_state(config, players_budget)
        for action_x in range(n_actions):
            next_budget_x, a_x = spend(action_x, budget_x)
            for action_y in range(n_actions):
                next_budget_y, a_y = spend(action_y, budget_y)
                # x's higher bid moves the marker down, y's moves it up,
                # a tie leaves it in place.
                next_grid_position = (grid_position + int(a_x <= a_y)
                                      - int(a_y <= a_x))
                terminal = False
                if next_grid_position == 0:
                    reward_nash[config, action_x, action_y] = -1.0
                    terminal = True
                elif next_grid_position == length + 1:
                    reward_nash[config, action_x, action_y] = 1.0
                    terminal = True
                if next_budget_x == 0 or next_budget_y == 0:
                    terminal = True
                if terminal:
                    next_index = terminal_index
                else:
                    next_index = get_state_index(
                        [next_grid_position, next_budget_x, next_budget_y],
                        players_budget)
                transition_matrix[next_index, config, action_x, action_y] = 1.0
    # The terminal state loops onto itself under every joint action.
    transition_matrix[terminal_index, terminal_index, :, :] = 1.0
    return transition_matrix, reward_nash


def create_single_agent_transition_matrix(fixed_policy, players_budget, length,
                                          fixed_player="y"):
    """Marginalise the joint Alesia dynamics over one player's fixed policy.

    Args:
        fixed_policy: array ``(actions, non-terminal states)``; column ``c``
            is the fixed player's action distribution in state ``c``.
        players_budget, length: game parameters forwarded to
            ``create_transition_matrix``.
        fixed_player: ``"y"`` or ``"x"``, the player whose policy is frozen.
            Any other value leaves only the terminal self-loop filled in.

    Returns:
        P: array ``(n + 1, n + 1, B)`` where ``P[next, current, a]`` is the
            probability of ``next`` when the free player plays ``a``; the last
            (terminal) index is absorbing.
        reward_nash: the joint reward tensor from ``create_transition_matrix``.
    """
    joint, reward_nash = create_transition_matrix(players_budget, length)
    n_total = joint.shape[0]
    P = np.zeros((n_total, n_total, joint.shape[3]))

    if fixed_player == "y":
        free_first = joint
    elif fixed_player == "x":
        # Swap the two action axes so the free (y) action sits on axis 2.
        free_first = joint.transpose(0, 1, 3, 2)
    else:
        free_first = None

    if free_first is not None:
        for current in range(n_total - 1):
            weights = fixed_policy[:, current]
            for action in range(joint.shape[2]):
                P[:, current, action] = free_first[:, current, action, :].dot(
                    weights)

    P[-1, -1, :] = 1.0
    return P, reward_nash


def value_iteration(P, r, players_budget, length, tol=1e-10, player="x",
                    gamma=0.99):
    """Discounted value iteration for one agent facing a fixed opponent.

    Args:
        P: transition tensor indexed ``P[next_state, state, action]``.
        r: expected immediate reward indexed ``r[state, action]``.
        players_budget, length: game parameters; they size the value vector
            (``(players_budget + 1) ** 2 * length + 1`` states) and the number
            of actions (``players_budget``).
        tol: stop once the Euclidean norm of the change in ``v`` is below this.
        player: ``"x"`` minimises the return, ``"y"`` maximises it.
        gamma: discount factor; the default 0.99 is the value previously
            hard-coded here.

    Returns:
        The converged value vector ``v``.

    Raises:
        ValueError: if ``player`` is neither ``"x"`` nor ``"y"``.  Previously
            an unknown player silently produced an all-zero vector.
    """
    if player == "x":
        best_over_actions = np.min
    elif player == "y":
        best_over_actions = np.max
    else:
        raise ValueError(f"player must be 'x' or 'y', got {player!r}")

    n_states = (players_budget + 1) ** 2 * length + 1
    n_actions = players_budget
    v = np.zeros(n_states)
    q = np.zeros((n_states, n_actions))
    while True:
        v_old = v  # v is rebound below, so no copy is needed
        for a in range(n_actions):
            # P[:, :, a].T[s, s'] == P[s', s, a]: expected next value from s.
            q[:, a] = r[:, a] + gamma * P[:, :, a].T.dot(v)
        v = best_over_actions(q, axis=1)
        if np.linalg.norm(v - v_old) < tol:
            break
    return v


def _zero_sum_value(payoff):
    """Return the value of the zero-sum matrix game ``payoff``.

    The row player maximises and the column player minimises
    ``p @ payoff @ q`` over mixed strategies.  The value comes from the LP
    ``max w  s.t.  p @ payoff[:, j] >= w for every column j, sum(p) = 1,
    p >= 0``.  In a zero-sum game every Nash equilibrium has this value.
    """
    from scipy.optimize import linprog

    payoff = np.asarray(payoff, dtype=float)
    n_rows, n_cols = payoff.shape
    # Decision variables: row strategy p (n_rows entries) followed by w.
    objective = np.zeros(n_rows + 1)
    objective[-1] = -1.0  # linprog minimises, so minimise -w
    a_ub = np.hstack([-payoff.T, np.ones((n_cols, 1))])  # w - p @ A[:, j] <= 0
    b_ub = np.zeros(n_cols)
    a_eq = np.concatenate([np.ones(n_rows), [0.0]])[np.newaxis, :]
    b_eq = np.ones(1)
    bounds = [(0, None)] * n_rows + [(None, None)]
    result = linprog(objective, A_ub=a_ub, b_ub=b_ub, A_eq=a_eq, b_eq=b_eq,
                     bounds=bounds)
    if not result.success:
        raise RuntimeError(
            f"linprog could not solve the matrix game: {result.message}")
    return result.x[-1]


def nash_dyn_programming(T, reward_nash, players_budget, length, tol=1e-10):
    """Backward induction of the zero-sum Alesia game via matrix-game values.

    Each sweep forms ``q[s, a, b] = reward_nash[s, a, b] + sum_s' T[s', s, a, b]
    v[s']`` (undiscounted) and sets ``v[s]`` to the value of the matrix game
    ``q[s]`` for every non-terminal state.  The last (terminal) state keeps
    value 0.

    The previous implementation called ``nash.Game`` (nashpy), but ``nash``
    was never imported, so any call raised NameError.  The same quantity, the
    row player's equilibrium payoff of ``Game(q[s])``, is now computed with a
    linear program.

    NOTE(review): rows are x's actions and the row player *maximises* here,
    whereas ``value_iteration`` treats x as the minimiser.  Confirm which
    convention is intended.

    Args:
        T: joint transition tensor ``T[next, current, action_x, action_y]``.
        reward_nash: joint reward ``reward_nash[state, action_x, action_y]``.
        players_budget, length: game parameters sizing the state/action spaces.
        tol: early-exit threshold on the norm of the change in ``v``.

    Returns:
        The value vector ``v`` of length ``(players_budget + 1) ** 2 * length + 1``.
    """
    n_states = (players_budget + 1) ** 2 * length + 1
    n_actions = players_budget
    v = np.zeros(n_states)
    q = np.zeros((n_states, n_actions, n_actions))
    # players_budget sweeps, presumably because every round spends at least
    # one unit of budget, which bounds the game's horizon.
    for _ in range(players_budget):
        v_old = np.copy(v)  # v is updated in place below
        for a in range(n_actions):
            for b in range(n_actions):
                q[:, a, b] = reward_nash[:, a, b] + T[:, :, a, b].T.dot(v)
        for s in range(n_states - 1):
            v[s] = _zero_sum_value(q[s])
        if np.linalg.norm(v - v_old) < tol:
            break
    return v


def score_policy(policy, players_budget, length, v_nash, player="x"):
    """Score a frozen policy by its opponent's best-response value.

    ``policy`` (actions x non-terminal states) is held fixed for ``player``.
    The other player's optimal value against it is computed with
    ``value_iteration`` (minimising for x, maximising for y).  That value is
    compared with ``v_nash`` at the initial state: middle of the board, both
    budgets full.

    Returns:
        ``v[start] - v_nash[start]`` for that initial state index.
    """
    P, reward_nash = create_single_agent_transition_matrix(
        policy, players_budget, length, fixed_player=player)
    responder = "x" if player == "y" else "y"

    n_states = (players_budget + 1) ** 2 * length + 1
    expected_reward = np.zeros((n_states, players_budget))
    for col in range(policy.shape[1]):
        mix = policy[:, col]
        for act in range(players_budget):
            if responder == "x":
                # x's action picks axis 1; average over the fixed y policy.
                payoffs = reward_nash[col, act, :]
            else:
                # y's action picks axis 2; average over the fixed x policy.
                payoffs = reward_nash[col, :, act]
            expected_reward[col, act] = payoffs.dot(mix)

    values = value_iteration(P, expected_reward, players_budget, length,
                             player=responder)
    start = get_state_index(
        [np.floor(length / 2) + 1, players_budget, players_budget],
        players_budget)
    return values[start] - v_nash[start]


def softmax_update(policy, weights):
    """Multiplicative-weights step: return ``policy * exp(-weights)``, normalised.

    Computed in log space (with a 1e-15 floor inside the log and a max shift)
    for numerical stability.  Intended for a single 1-D distribution.
    """
    logits = np.log(policy + 1e-15) - weights
    unnormalised = np.exp(logits - logits.max())
    return unnormalised / unnormalised.sum()


class ReflectedNAC:
    """Reflected natural actor-critic learner for the Alesia game.

    Every policy / critic table has shape
    ``(players_budget, (players_budget + 1) ** 2 * length)``: one row per
    action, one column per non-terminal state (see ``get_state_index``).
    Player x is pushed away from high-critic actions and player y towards
    them (the signs handed to ``softmax_update``).  So x minimises and y
    maximises the environment reward.

    ``run`` alternates a reflected softmax policy-update phase
    (``greedy_step``) with a SARSA-based response phase (``evaluation_step``
    / ``evaluation_step_x``).
    """

    def __init__(self, players_budget, length):
        self.players_budget = players_budget
        self.length = length
        shape = (players_budget, (players_budget + 1) ** 2 * length)
        self.policy_x = np.ones(shape) / players_budget  # Actions x States
        self.policy_y = np.ones(shape) / players_budget  # Actions x States
        self.policy_y_bar = np.ones(shape) / players_budget
        self.env = AlesiaEnv(players_budget, length)
        self.critic_x = np.zeros(shape)  # Actions x States
        self.critic_y = np.zeros(shape)  # Actions x States
        self.old_critic_x = np.zeros(shape)  # Actions x States
        self.old_critic_y = np.zeros(shape)  # Actions x States
        self.evaluation_action_critic = np.zeros(shape)
        self.policy_x_ks = [np.ones(shape) / players_budget]
        self.policy_y_ks = [np.ones(shape) / players_budget]
        # Uniform-exploration mass mixed into every behaviour policy.
        self.epsilon = 0.4

    def get_state_index(self, state):
        """Flatten ``(position, budget_x, budget_y)`` into a column index."""
        side = self.players_budget + 1
        # int() replaces np.int, which NumPy 1.24 removed.
        return int((state[0] - 1) * side ** 2 + state[1] * side + state[2])

    def _eps_greedy(self, policy):
        """Mix ``policy`` with the uniform distribution using weight ``epsilon``."""
        return (policy * (1 - self.epsilon)
                + self.epsilon / self.players_budget)

    def _sample_joint_action(self, eps_greedy_x, eps_greedy_y, index_state):
        """Draw x's action, then y's, from the given behaviour policies."""
        action_x = np.random.choice(self.players_budget,
                                    p=eps_greedy_x[:, index_state])
        action_y = np.random.choice(self.players_budget,
                                    p=eps_greedy_y[:, index_state])
        return action_x, action_y

    def V_hat_eval(self, N):
        """Estimate state values with TD(0) under the latest stored policies.

        Values start uniformly random.  ``N`` environment steps are taken with
        epsilon-greedy versions of ``policy_x_ks[-1]`` / ``policy_y_ks[-1]``
        and step size ``1e-2 / (n + 1)``.
        """
        V_hat = np.random.uniform(
            size=(self.players_budget + 1) ** 2 * self.length)
        eps_greedy_x = self._eps_greedy(self.policy_x_ks[-1])
        eps_greedy_y = self._eps_greedy(self.policy_y_ks[-1])
        state = self.env.reset()
        for n in range(N):
            index_state = self.get_state_index(state)
            beta_n = 1e-2 / (n + 1)
            action_x, action_y = self._sample_joint_action(
                eps_greedy_x, eps_greedy_y, index_state)
            new_state, reward, done = self.env.step(action_x, action_y)
            if done:
                V_hat[index_state] -= beta_n * (V_hat[index_state] - reward)
                state = self.env.reset()
            else:
                index_new_state = self.get_state_index(new_state)
                V_hat[index_state] -= beta_n * (
                        V_hat[index_state] - reward - V_hat[index_new_state])
                state = new_state
        return V_hat

    def critic_eval(self, N, V_hat, player="x"):
        """Update ``critic_x`` or ``critic_y`` in place by bootstrapped TD.

        The target is ``reward + V_hat[next]``, or just ``reward`` on the
        terminal step.  Step size is ``1e-1 / (n + 1)``.

        Raises:
            ValueError: if ``player`` is neither ``"x"`` nor ``"y"``.
        """
        if player == "x":
            critic, learner_slot = self.critic_x, 0
        elif player == "y":
            critic, learner_slot = self.critic_y, 1
        else:
            raise ValueError(f"player must be 'x' or 'y', got {player!r}")
        state = self.env.reset()
        eps_greedy_x = self._eps_greedy(self.policy_x)
        eps_greedy_y = self._eps_greedy(self.policy_y)
        for n in range(N):
            index_state = self.get_state_index(state)
            joint_action = self._sample_joint_action(
                eps_greedy_x, eps_greedy_y, index_state)
            new_state, reward, done = self.env.step(*joint_action)
            action = joint_action[learner_slot]
            beta_n = 1e-1 / (n + 1)
            if done:
                error = critic[action, index_state] - reward
                state = self.env.reset()
            else:
                index_new_state = self.get_state_index(new_state)
                error = (critic[action, index_state] - reward
                         - V_hat[index_new_state])
                state = new_state
            critic[action, index_state] -= beta_n * error

    def greedy_step(self, T, N, return_player="x"):
        """Run ``T`` reflected softmax updates and return an averaged policy.

        Each iteration re-estimates both critics, then moves every state's
        policy along the reflected critic ``2 * critic - old_critic``.  x moves
        with ``+eta`` (away from high values) and y with ``-eta``.

        Returns:
            The average over iterations of x's (or y's) policy.  That player's
            current policy is also replaced by this average.

        Raises:
            ValueError: if ``return_player`` is neither ``"x"`` nor ``"y"``.
        """
        if return_player not in ("x", "y"):
            raise ValueError(
                f"return_player must be 'x' or 'y', got {return_player!r}")
        eta = 5e-2
        shape = (self.players_budget,
                 (self.players_budget + 1) ** 2 * self.length)
        x_k = np.zeros(shape)
        y_k = np.zeros(shape)
        for _ in range(T):
            self.old_critic_x = self.critic_x.copy()
            self.old_critic_y = self.critic_y.copy()
            V_hat_x = self.V_hat_eval(N)
            V_hat_y = self.V_hat_eval(N)
            self.critic_eval(N, V_hat_x, player="x")
            self.critic_eval(N, V_hat_y, player="y")
            reflected_x = 2 * self.critic_x - self.old_critic_x
            reflected_y = 2 * self.critic_y - self.old_critic_y
            for s in range(self.policy_x.shape[1]):
                self.policy_x[:, s] = softmax_update(self.policy_x[:, s],
                                                     eta * reflected_x[:, s])
                self.policy_y[:, s] = softmax_update(self.policy_y[:, s],
                                                     -eta * reflected_y[:, s])
            x_k += self.policy_x / T
            y_k += self.policy_y / T
        if return_player == "x":
            self.policy_x = x_k.copy()
            return x_k
        self.policy_y = y_k.copy()
        return y_k

    def _evaluation_sweeps(self, T, N, eps_greedy_x, eps_greedy_y, learner):
        """SARSA on ``evaluation_action_critic``, then softmax on ``policy_y_bar``.

        A random number of sweeps ``t_hat ~ U{0..T-1}`` is run.  Each sweep
        takes ``N`` environment steps, updating the critic for ``learner``'s
        action, then moves ``policy_y_bar`` with ``-eta`` (learner y,
        maximising) or ``+eta`` (learner x, minimising).
        """
        t_hat = np.random.choice(T)
        eta = 1e-3
        step = -eta if learner == "y" else eta
        learner_slot = 1 if learner == "y" else 0
        learner_eps = eps_greedy_y if learner == "y" else eps_greedy_x
        q = self.evaluation_action_critic
        state = self.env.reset()
        for _ in range(t_hat):
            for n in range(N):
                index_state = self.get_state_index(state)
                beta_n = 1e-2 / (n + 1)
                joint_action = self._sample_joint_action(
                    eps_greedy_x, eps_greedy_y, index_state)
                new_state, reward, done = self.env.step(*joint_action)
                action = joint_action[learner_slot]
                if done:
                    error = q[action, index_state] - reward
                    state = self.env.reset()
                else:
                    index_new_state = self.get_state_index(new_state)
                    next_action = np.random.choice(
                        self.players_budget,
                        p=learner_eps[:, index_new_state])
                    error = (q[action, index_state] - reward
                             - q[next_action, index_new_state])
                    state = new_state
                q[action, index_state] -= beta_n * error
            for s in range(self.policy_y_bar.shape[1]):
                self.policy_y_bar[:, s] = softmax_update(
                    self.policy_y_bar[:, s], step * q[:, s])

    def evaluation_step(self, T, x_k, N):
        """Train y's response (``policy_y_bar``) against the fixed x policy ``x_k``.

        Returns:
            A copy of the updated ``policy_y_bar``; ``policy_y`` is set to it
            as well.  A copy is returned because the caller stores it in
            ``policy_y_ks`` and later sweeps update ``policy_y_bar`` in place.
        """
        eps_greedy_x = self._eps_greedy(x_k)
        eps_greedy_y = self._eps_greedy(self.policy_y_bar)
        self._evaluation_sweeps(T, N, eps_greedy_x, eps_greedy_y, learner="y")
        self.policy_y = self.policy_y_bar.copy()
        return self.policy_y_bar.copy()

    def evaluation_step_x(self, T, y_k, N):
        """Train x's response against the fixed y policy ``y_k``.

        NOTE(review): x's learning policy lives in ``policy_y_bar``, the same
        buffer ``evaluation_step`` uses for y.  ``run`` only calls one of the
        two per invocation; confirm this reuse is intended.

        Returns:
            A copy of the updated buffer; ``policy_x`` is set to it as well.
        """
        eps_greedy_x = self._eps_greedy(self.policy_y_bar)
        eps_greedy_y = self._eps_greedy(y_k)
        self._evaluation_sweeps(T, N, eps_greedy_x, eps_greedy_y, learner="x")
        self.policy_x = self.policy_y_bar.copy()
        return self.policy_y_bar.copy()

    def run(self, K, T, N, return_player="x", scoring=False):
        """Run ``K`` outer iterations and optionally score the returned policy.

        The joint transition tensor is no longer built here: it was unused,
        and ``score_policy`` builds what it needs itself.

        Returns:
            ``np.array`` of per-iteration scores (empty if ``scoring`` is False).

        Raises:
            ValueError: if ``return_player`` is neither ``"x"`` nor ``"y"``.
        """
        if return_player not in ("x", "y"):
            raise ValueError(
                f"return_player must be 'x' or 'y', got {return_player!r}")
        # NOTE(review): the original comment called this "Precomputed", but it
        # is an all-zero placeholder, so scores are raw best-response values.
        v_nash = np.zeros((self.players_budget + 1) ** 2 * (self.length + 2))
        scores = []
        for _ in range(K):
            if return_player == "x":
                x_k = self.greedy_step(T, N)
                self.policy_x_ks.append(x_k)
                self.policy_y_ks.append(self.evaluation_step(T, x_k, N))
                latest = self.policy_x_ks[-1]
            else:
                y_k = self.greedy_step(T, N, return_player="y")
                self.policy_y_ks.append(y_k)
                self.policy_x_ks.append(self.evaluation_step_x(T, y_k, N))
                latest = self.policy_y_ks[-1]
            if scoring:
                scores.append(score_policy(latest,
                                           self.players_budget,
                                           self.length, v_nash,
                                           player=return_player))
        return np.array(scores)


def projection_simplex_pivot(v, z=1, random_state=None):
    """Euclidean projection of ``v`` onto the simplex ``{w >= 0, sum(w) = z}``.

    Uses the randomised-pivot algorithm: repeatedly pick a random pivot among
    the undecided coordinates and decide whether it lies in the support.
    The threshold ``theta`` is then subtracted and negatives are clipped.

    Args:
        v: 1-D array to project.
        z: simplex radius; must be positive.
        random_state: seed for the pivot choice.  ``None`` draws fresh OS
            entropy and does not touch NumPy's global RNG.

    Returns:
        The projected array.  It is renormalised to sum exactly ``z``, which
        cancels rounding (callers feed it to ``np.random.choice``).
        Previously it was always renormalised to 1, which was wrong for
        ``z != 1``; for ``z == 1`` the result is unchanged.

    Raises:
        ValueError: if ``v`` is empty.
    """
    n_features = len(v)
    if n_features == 0:
        raise ValueError("cannot project an empty vector onto the simplex")
    rs = np.random.RandomState(random_state)
    U = np.arange(n_features)
    s = 0  # sum of coordinates known to be in the support
    rho = 0  # number of coordinates known to be in the support
    while len(U) > 0:
        G = []
        L = []
        k = U[rs.randint(0, len(U))]
        ds = v[k]
        for j in U:
            if v[j] >= v[k]:
                if j != k:
                    ds += v[j]
                    G.append(j)
            elif v[j] < v[k]:
                L.append(j)
        drho = len(G) + 1
        if s + ds - (rho + drho) * v[k] < z:
            # The pivot and everything above it are in the support.
            s += ds
            rho += drho
            U = L
        else:
            U = G
    theta = (s - z) / float(rho)
    out = np.maximum(v - theta, 0)
    return out / np.sum(out) * z


class OGDA:
    """Optimistic gradient descent/ascent (OGDA) learner for the Alesia game.

    Policies are ``(players_budget, n_states)`` arrays with
    ``n_states = (players_budget + 1) ** 2 * length``.  Each iteration
    estimates per-(action, state) returns by Monte Carlo and takes an
    optimistic projected step: x descends (minimises), y ascends (maximises).
    """

    def __init__(self, players_budget, length):
        self.eta = 5e-3
        self.players_budget = players_budget
        self.length = length
        n_states = (players_budget + 1) ** 2 * length
        shape = (players_budget, n_states)
        # Policies actually played.
        self.policy_x = np.ones(shape) / players_budget
        self.policy_y = np.ones(shape) / players_budget
        # Auxiliary ("hat") iterates of the optimistic update.
        self.hat_policy_x = np.ones(shape) / players_budget
        self.hat_policy_y = np.ones(shape) / players_budget
        self.env = AlesiaEnv(players_budget, length)
        self.V = np.zeros(n_states)
        self.mc_estimate_x = np.zeros(shape)  # l_t in the paper
        self.mc_estimate_y = np.zeros(shape)  # r_t in the paper
        self.mc_estimate_V = np.zeros(n_states)  # rho in the paper
        self.policies_x = [np.ones(shape) / players_budget]
        self.policies_y = [np.ones(shape) / players_budget]

    def get_state_index(self, state):
        """Flatten ``(position, budget_x, budget_y)`` into a column index."""
        side = self.players_budget + 1
        # int() replaces np.int, which NumPy 1.24 removed.
        return int((state[0] - 1) * side ** 2 + state[1] * side + state[2])

    def run(self, T, L, scoring=False):
        """Run ``T`` OGDA iterations, each with ``L`` Monte-Carlo steps.

        Returns:
            ``(scores_x, scores_y)`` arrays, empty unless ``scoring`` is True.
        """
        # NOTE(review): all-zero placeholder baseline for score_policy.
        v_nash = np.zeros(
            (self.players_budget + 1) ** 2 * (self.length + 2))
        scores_x = []
        scores_y = []
        for t in range(T):
            alpha = 1e-2 / (t + 1)  # value-averaging step size
            self.monte_carlo_evaluate(L)
            self.update_policies(alpha)
            if scoring:
                scores_x.append(score_policy(self.policies_x[-1],
                                             self.players_budget,
                                             self.length, v_nash))
                scores_y.append(score_policy(self.policies_y[-1],
                                             self.players_budget,
                                             self.length, v_nash,
                                             player="y"))
        return np.array(scores_x), np.array(scores_y)

    def update_policies(self, alpha):
        """Take one optimistic projected step per state and average ``V``.

        Both the hat iterate and the played policy are stepped from the new
        hat iterate along the same estimate.  A snapshot of each policy is
        then appended to the history.  Previously the live arrays were
        appended, so every history entry aliased the same buffer.
        """
        for s in range((self.players_budget + 1) ** 2 * self.length):
            loss_x = self.mc_estimate_x[:, s]
            gain_y = self.mc_estimate_y[:, s]
            self.hat_policy_x[:, s] = projection_simplex_pivot(
                self.hat_policy_x[:, s] - self.eta * loss_x)
            self.policy_x[:, s] = projection_simplex_pivot(
                self.hat_policy_x[:, s] - self.eta * loss_x)
            self.hat_policy_y[:, s] = projection_simplex_pivot(
                self.hat_policy_y[:, s] + self.eta * gain_y)
            self.policy_y[:, s] = projection_simplex_pivot(
                self.hat_policy_y[:, s] + self.eta * gain_y)

        self.V = (1 - alpha) * self.V + alpha * self.mc_estimate_V

        self.policies_x.append(self.policy_x.copy())
        self.policies_y.append(self.policy_y.copy())

    def monte_carlo_evaluate(self, L):
        """Estimate per-(action, state) returns from ``L`` environment steps.

        Each sample's return is ``reward + V[next_state]``, or just
        ``reward`` on a terminal step.  The behaviour policies mix in 0.4
        uniform exploration.  Pairs never visited get an estimate of 0.
        """
        n_states = (self.players_budget + 1) ** 2 * self.length
        shape = (self.players_budget, n_states)
        counter_x = np.zeros(shape)
        counter_y = np.zeros(shape)
        rewards_x = np.zeros(shape)
        rewards_y = np.zeros(shape)
        self.mc_estimate_V = np.zeros(n_states)
        epsilon = 0.4
        eps_greedy_x = (self.policy_x * (1 - epsilon)
                        + epsilon / self.players_budget)
        eps_greedy_y = (self.policy_y * (1 - epsilon)
                        + epsilon / self.players_budget)
        state = self.env.reset()

        for _ in range(L):
            index_state = self.get_state_index(state)
            action_x = np.random.choice(self.players_budget,
                                        p=eps_greedy_x[:, index_state])
            action_y = np.random.choice(self.players_budget,
                                        p=eps_greedy_y[:, index_state])
            new_state, reward, done = self.env.step(action_x, action_y)
            if done:
                rewards_x[action_x, index_state] = (
                        rewards_x[action_x, index_state] + reward)
                rewards_y[action_y, index_state] = (
                        rewards_y[action_y, index_state] + reward)
            else:
                next_value = self.V[self.get_state_index(new_state)]
                rewards_x[action_x, index_state] = (
                        rewards_x[action_x, index_state] + reward + next_value)
                rewards_y[action_y, index_state] = (
                        rewards_y[action_y, index_state] + reward + next_value)
            counter_x[action_x, index_state] += 1
            counter_y[action_y, index_state] += 1
            state = self.env.reset() if done else new_state
            # NOTE(review): this adds the scalar reward plus the *whole* value
            # vector V to every state's entry on every step, regardless of
            # which state was visited.  A per-state estimate (updating only
            # index_state) may have been intended; confirm before trusting V.
            self.mc_estimate_V += (reward + self.V) / L
        # x / inf == 0, so unvisited (action, state) pairs get a zero estimate.
        counter_x[counter_x == 0] = np.inf
        counter_y[counter_y == 0] = np.inf
        self.mc_estimate_x = rewards_x / counter_x
        self.mc_estimate_y = rewards_y / counter_y


class Reinforce:
    """Projected REINFORCE-style learner for the Alesia game.

    Each iteration accumulates importance-weighted rewards
    (``reward / behaviour probability``) per (action, state) over ``L`` steps.
    It then takes one projected step per state: x descends and y ascends.
    ``hat_policy_*``, ``V`` and ``mc_estimate_V`` are allocated, but no
    method of this class reads them.
    """

    def __init__(self, players_budget, length):
        self.eta = 1e-4
        self.players_budget = players_budget
        self.length = length
        n_states = (players_budget + 1) ** 2 * length
        shape = (players_budget, n_states)
        self.policy_x = np.ones(shape) / players_budget
        self.policy_y = np.ones(shape) / players_budget
        self.hat_policy_x = np.ones(shape) / players_budget
        self.hat_policy_y = np.ones(shape) / players_budget
        self.env = AlesiaEnv(players_budget, length)
        self.V = np.zeros(n_states)
        self.mc_estimate_x = np.zeros(shape)  # l_t in the paper
        self.mc_estimate_y = np.zeros(shape)  # r_t in the paper
        self.mc_estimate_V = np.zeros(n_states)  # rho in the paper
        self.policies_x = [np.ones(shape) / players_budget]
        self.policies_y = [np.ones(shape) / players_budget]

    def get_state_index(self, state):
        """Flatten ``(position, budget_x, budget_y)`` into a column index."""
        side = self.players_budget + 1
        # int() replaces np.int, which NumPy 1.24 removed.
        return int((state[0] - 1) * side ** 2 + state[1] * side + state[2])

    def run(self, T, L, scoring=False):
        """Run ``T`` iterations, each with ``L`` sampled environment steps.

        Returns:
            ``(scores_x, scores_y)`` arrays, empty unless ``scoring`` is True.
        """
        # NOTE(review): all-zero placeholder baseline for score_policy.
        v_nash = np.zeros(
            (self.players_budget + 1) ** 2 * (self.length + 2))
        scores_x = []
        scores_y = []
        for t in range(T):
            alpha = 1e-4 / (t + 1)  # 1e-2
            self.monte_carlo_evaluate(L)
            self.update_policies(alpha)
            if scoring:
                scores_x.append(score_policy(self.policies_x[-1],
                                             self.players_budget,
                                             self.length, v_nash))
                scores_y.append(score_policy(self.policies_y[-1],
                                             self.players_budget,
                                             self.length, v_nash,
                                             player="y"))
        return np.array(scores_x), np.array(scores_y)

    def update_policies(self, alpha):
        """Take one projected gradient step per state and record snapshots.

        ``alpha`` is accepted for interface parity with ``OGDA`` but is
        unused.  Copies are appended to the history because the policies are
        updated in place; appending the live arrays made every history entry
        alias the same buffer.
        """
        for s in range((self.players_budget + 1) ** 2 * self.length):
            self.policy_x[:, s] = projection_simplex_pivot(
                self.policy_x[:, s] - self.eta * self.mc_estimate_x[:, s])
            self.policy_y[:, s] = projection_simplex_pivot(
                self.policy_y[:, s] + self.eta * self.mc_estimate_y[:, s])

        self.policies_x.append(self.policy_x.copy())
        self.policies_y.append(self.policy_y.copy())

    def monte_carlo_evaluate(self, L):
        """Accumulate importance-weighted rewards over ``L`` steps.

        The behaviour policies mix in 0.1 uniform exploration.  Each sampled
        reward is divided by the probability of the action that produced it.
        Sums are not averaged over visits.
        """
        shape = (self.players_budget,
                 (self.players_budget + 1) ** 2 * self.length)
        rewards_x = np.zeros(shape)
        rewards_y = np.zeros(shape)
        epsilon = 0.1
        eps_greedy_x = (self.policy_x * (1 - epsilon)
                        + epsilon / self.players_budget)
        eps_greedy_y = (self.policy_y * (1 - epsilon)
                        + epsilon / self.players_budget)
        state = self.env.reset()

        for _ in range(L):
            index_state = self.get_state_index(state)
            action_x = np.random.choice(self.players_budget,
                                        p=eps_greedy_x[:, index_state])
            action_y = np.random.choice(self.players_budget,
                                        p=eps_greedy_y[:, index_state])
            new_state, reward, done = self.env.step(action_x, action_y)
            rewards_x[action_x, index_state] += (
                    reward / eps_greedy_x[action_x, index_state])
            rewards_y[action_y, index_state] += (
                    reward / eps_greedy_y[action_y, index_state])
            state = self.env.reset() if done else new_state
        self.mc_estimate_x = rewards_x
        self.mc_estimate_y = rewards_y


def run_alesia_experiment(players_budget, length, seeds=np.arange(1),
                          player="x"):
    """Train Reflected NAC, OGDA and REINFORCE once per seed and collect scores.

    For each seed the seed is printed and NumPy's global RNG is re-seeded.
    Then the three learners are built and run in that order.

    Args:
        players_budget: starting budget of each player.
        length: number of cells on the Alesia board.
        seeds: iterable of seeds, one full run of every learner per seed.
        player: which player Reflected NAC returns and scores (``"x"``/``"y"``).

    Returns:
        A tuple of five lists, each with one entry per seed: Reflected NAC
        scores, OGDA x scores, OGDA y scores, REINFORCE x scores and
        REINFORCE y scores.
    """
    results = tuple([] for _ in range(5))
    for seed in seeds:
        print(seed)
        np.random.seed(seed)

        nac = ReflectedNAC(players_budget=players_budget, length=length)
        ogda = OGDA(players_budget=players_budget, length=length)
        reinforce = Reinforce(players_budget=players_budget, length=length)

        # Fixed schedule.  The original trailing comment "100 100 10"
        # presumably records an earlier K/T/N setting.
        nac_scores = np.array(nac.run(K=5, T=100, N=70,
                                      return_player=player, scoring=True))
        ogda_x, ogda_y = ogda.run(10, 3500, scoring=True)
        reinforce_x, reinforce_y = reinforce.run(10, 3500, scoring=True)

        row = (nac_scores, ogda_x, ogda_y, reinforce_x, reinforce_y)
        for bucket, value in zip(results, row):
            bucket.append(value)
    return results


parser = argparse.ArgumentParser()
parser.add_argument("--length", type=int)
parser.add_argument("--budget", type=int)
parser.add_argument("--savename", type=str)
parser.add_argument("--player", type=str, default="x")
args = parser.parse_args()
results = run_alesia_experiment(players_budget=args.budget,
                                length=args.length, player=args.player)

with open(args.savename + ".pkl", "wb") as f:
    pickle.dump(results, f)
